import os
import h5py
import torch
import numpy as np
from tqdm import tqdm
import sys
sys.path.append("./mindeye2_src")
from mindeye2_src.models import Clipper

# Select the compute device: the first CUDA GPU if available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Load the CLIP image encoder through the MindEye2 `Clipper` wrapper.
# hidden_state=True presumably requests per-token hidden states rather than a
# single pooled vector (consistent with the 3-D output shape below), and
# norm_embs=True is forwarded to Clipper — its exact effect is defined in
# mindeye2_src.models; TODO confirm both there.
clipper = Clipper("ViT-L/14", device=device, hidden_state=True, norm_embs=True)
clipper = clipper.to(device)

# Input / output locations.
data_path = './dataset'
input_path = f'{data_path}/coco_images_224_float16.hdf5'
output_path = f'{data_path}/clip_embeddings.hdf5'

# Per-image embedding shape stored on disk: (tokens, dim).
# Presumably 257 = 16x16 patch tokens (224 / 14 = 16) + 1 class token and
# 768 = ViT-L/14's projected embedding width — TODO confirm against
# Clipper.embed_image.
EMBED_SHAPE = (257, 768)
batch_size = 32

# Context managers guarantee both HDF5 files are closed (and the output is
# flushed) even if embedding fails part-way, instead of leaking open handles
# and leaving a half-written output file open.
with h5py.File(input_path, 'r') as image_file, \
        h5py.File(output_path, 'w') as output_file:
    images = image_file['images']
    num_images = len(images)

    # Note: opening with 'w' truncates any existing clip_embeddings.hdf5.
    # The result is a 3-D dataset of shape (num_images, 257, 768).
    output_embeddings = output_file.create_dataset(
        'embeddings', (num_images, *EMBED_SHAPE), dtype=np.float16
    )

    # Walk the dataset in fixed-size batches; the final batch may be shorter.
    for start in tqdm(range(0, num_images, batch_size)):
        end = min(start + batch_size, num_images)
        # Slicing an h5py dataset reads only this batch from disk into a fresh
        # NumPy array, so from_numpy can wrap it without an extra host copy.
        # The stored dtype is presumably float16 (per the file name) — TODO
        # confirm Clipper accepts that dtype on the chosen device.
        batch_images = torch.from_numpy(images[start:end]).to(device)
        with torch.no_grad():
            embeddings = clipper.embed_image(batch_images).cpu().numpy().astype(np.float16)

        # Fail with a clear message instead of an opaque h5py broadcast error
        # if the model's output does not match the pre-allocated layout.
        if tuple(embeddings.shape[1:]) != EMBED_SHAPE:
            raise ValueError(
                f"Unexpected embedding shape {embeddings.shape} for batch "
                f"[{start}:{end}]; expected (batch, {EMBED_SHAPE[0]}, {EMBED_SHAPE[1]})"
            )
        output_embeddings[start:end] = embeddings

print("CLIP embeddings saved to clip_embeddings.hdf5")